{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MnistDataset(Dataset):\n",
    "    \"\"\"MNIST samples read from a header-less CSV file.\n",
    "\n",
    "    Column 0 holds the digit label and the remaining columns hold the pixel\n",
    "    values, which plot_image reshapes to 28x28. Pixels are divided by 255 in\n",
    "    __getitem__, so presumably they are stored as 0-255 integers (TODO confirm).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, csv_file) -> None:\n",
    "        # The whole CSV is loaded once; individual rows are sliced on access.\n",
    "        self.data_df = pd.read_csv(csv_file, header=None)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data_df)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return (label, pixels / 255 as a float tensor, one-hot target of length 10).\"\"\"\n",
    "        # iloc raises IndexError past the last row, which is also what ends a\n",
    "        # plain `for ... in dataset` loop (Dataset defines no __iter__).\n",
    "        digit = self.data_df.iloc[index, 0]\n",
    "        pixels = self.data_df.iloc[index, 1:].values\n",
    "        one_hot = torch.zeros(10)\n",
    "        one_hot[digit] = 1.0\n",
    "        return digit, torch.FloatTensor(pixels) / 255.0, one_hot\n",
    "\n",
    "    def plot_image(self, index):\n",
    "        \"\"\"Draw row `index` as a 28x28 image titled with its label.\"\"\"\n",
    "        image = self.data_df.iloc[index, 1:].values.reshape(28, 28)\n",
    "        plt.title(f'label = {self.data_df.iloc[index, 0]}')\n",
    "        plt.imshow(image, interpolation='none', cmap='Blues')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CSV at this relative path (resolved against the kernel's working directory).\n",
    "mnist_dataset = MnistDataset(csv_file='mnist_dataset/mnist_train.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAPQUlEQVR4nO3dfZBd9V3H8fcHCpXHQro7MaRpQiG0A7QG3KZFGKRSIWQcAz4whBGjMA21oK2CI+If4MMo1UJl1IKhQUKhQEfKwIwoYOqIWAtZmACB1ELTBJLmYbcUQkSkIV//uCdlWfaeu7n33Hvu5vt5zdzZs+f3O/d898An597z9FNEYGZ7v33qLsDMesNhN0vCYTdLwmE3S8JhN0vCYTdLwmGfgiStl/TJSfYNSUe3uZ62l7X+47BbLSTNlfS6pNvqriULh93q8nfAqrqLyMRhn+IkzZf0X5JelrRZ0t9K2n9ct4WS1kkalfRXkvYZs/yFktZK+qGkByTN7kHN5wEvAyu7vS57i8M+9b0J/C4wAJwEnA58Zlyfc4Ah4ERgEXAhgKRFwJXALwGDwH8Ad0xmpZK+VPwDM9HrqZLlDgX+BPi9PfgbrQIO+xQXEY9HxLciYmdErAf+HvjZcd0+HxEvRcQLwF8Di4v5nwb+IiLWRsRO4M+BeZPZu0fEZyLisCavj5Qs+qfA8ojYuId/qnXoXXUXYJ2RdAxwHY0994E0/ps+Pq7bi2OmNwBHFNOzgeslXTv2LYGZRb+qa50HfBI4oer3tta8Z5/6bgC+DcyNiENpfCzXuD6zxky/H/h+Mf0icPG4vfIBEfHNViuVdKOkHU1ezzRZ7DRgDvCCpC3A5cAvS3pisn+stc9hn/oOAbYDOyR9CPitCfr8vqTDJc0CPgvcVcy/EfhDSccBSHqPpF+dzEoj4tMRcXCT13FNFlsGHAXMK143Av8EnDmpv9Q64rBPfZcD5wOvAjfxVpDHupfGR/vVNMK1HCAi7gE+D9wpaTuwBjirW4VGxGsRsWX3C9gBvB4RI91ap71FfniFWQ7es5sl4bCbJeGwmyXhsJsl0dOLagYGBmL27Dm9XKVZKhs2rGd0dHT8dRZAh2GXtAC4HtgX+HJEXFPWf/bsOfzno8OdrNLMSpz8saGmbW1/jJe0L43bFM8CjgUWSzq23fczs+7q5Dv7fOD5iFgXEW8Ad9K4o8rM+lAnYZ/J22+w2FjMextJSyUNSxoeGfWFUmZ16frR+IhYFhFDETE0ODDY7dWZWROdhH0Tb7+b6n3FPDPrQ52EfRUwV9KRxWOQzgPuq6YsM6ta26feImKnpEuBB2icers5Iprdx2xmNevoPHtE3A/cX1EtZtZFvlzWLAmH3SwJh90sCYfdLAmH3SwJh90sCYfdLAmH3SwJh90sCYfdLAmH3SwJh90sCYfdLAmH3SwJh90sCYfdLAmH3SwJh90sCYfdLAmH3SwJh90sCYfdLAmH3SwJh90sCYfdLAmH3SwJh90sCYfdLAmH3SyJjkZxNetnj617qWnbmb+zonTZ1bdcXNo+e+DAtmqqU0dhl7QeeBV4E9gZEUNVFGVm1atiz/6JiBit4H3MrIv8nd0siU7DHsCDkh6XtHSiDpKWShqWNDwyOtLh6sysXZ2G/ZSIOBE4C7hE0qnjO0TEsogYioihwYHBDldnZu3qKOwRsan4uQ24B5hfRVFmVr22wy7pIEmH7J4GzgDWVFWYmVWrk6Px04F7JO1+n69GxL9UUlUXrF7/cmn71tdeL20/89ifrLAa64X7n2t+jOjoj364h5X0h7bDHhHrgJ+qsBYz6yKfejNLwmE3S8JhN0vCYTdLwmE3SyLNLa5fXbO5tP3J7zW/HRJ86q0f7doVpe1Pv/hK07bvv1B+6XZE+XtPRd6zmyXhsJsl4bCbJeGwmyXhsJsl4bCbJeGwmyWR5jz78tu/Vdo+9DPH9KgSq8rojjdK27+x7CtN2z72
6+eVLjtn8KC2aupn3rObJeGwmyXhsJsl4bCbJeGwmyXhsJsl4bCbJZHmPPuuXbvqLsEqtuAL/972svOOnFZhJVOD9+xmSTjsZkk47GZJOOxmSTjsZkk47GZJOOxmSew159mf37KjvMOWdS3e4UOV1WK9sf2V19pe9tc+ckSFlUwNLffskm6WtE3SmjHzpkl6SNJzxc/Du1ummXVqMh/jbwEWjJt3BbAyIuYCK4vfzayPtQx7RDwMjB8baRGwopheAZxdbVlmVrV2D9BNj4jdg6dtAaY36yhpqaRhScMjo+Xja5lZ93R8ND4aI+A1HQUvIpZFxFBEDA0ODHa6OjNrU7th3yppBkDxc1t1JZlZN7Qb9vuAJcX0EuDeasoxs25peZ5d0h3AacCApI3AVcA1wNckXQRsAM7tZpGTcduTm8o7/O/23hRilflBi+fC/2Dd+rbfe/DQd7e97FTVMuwRsbhJ0+kV12JmXeTLZc2ScNjNknDYzZJw2M2ScNjNkthrbnFdtW785ft75qRjBiqqxKpy/j+sKu+w9bulzft98KNN2w7cf992SprSvGc3S8JhN0vCYTdLwmE3S8JhN0vCYTdLwmE3S2KvOc/eqYVz/RSddvzP6ztL27/xfPPnmvzZPWtLl/3OAw+0VdNuX7r855q2HXLAfh2991TkPbtZEg67WRIOu1kSDrtZEg67WRIOu1kSDrtZEj7PXtj22uu1rfu5FsNNNwbdae4fn93ctO3bm8vf+/9+9GZp+4O33V/azq7y5Tng0KZNcz8+r3zZdx9U3r6z/FHTQ0d4cOGxvGc3S8JhN0vCYTdLwmE3S8JhN0vCYTdLwmE3S2KvOc9+8E+0+FNU/u/aBZetKH//2R/Y05ImbceT3yzvELvK29+1f/O2Aw8rXfS9x324tH3Rhb9Y2v4Lx5c/B+DUOc2fxz/t4JK6gcFHnyxtbzUM95zBFufpk2m5Z5d0s6RtktaMmXe1pE2SVhevhd0t08w6NZmP8bcACyaY/8WImFe8WlxmZWZ1axn2iHgY6GxsJTOrXScH6C6V9FTxMb/pRciSlkoaljQ8MjrSwerMrBPthv0G4ChgHrAZuLZZx4hYFhFDETE0OOCHOprVpa2wR8TWiHgzInYBNwHzqy3LzKrWVtglzRjz6znAmmZ9zaw/tDzPLukO4DRgQNJG4CrgNEnzgADWAxd3r8TJues3m4/FDfDHM5vfVw3wz49trLKcPfP+8nPZlyw4urT9hOmHNW07ftZ72qmoJ7786PfKO4ysL23e56gTqysmgZZhj4jFE8xe3oVazKyLfLmsWRIOu1kSDrtZEg67WRIOu1kSe80trq1cdcYHO2q36t3+yAsdLf8r5/x0RZXk4D27WRIOu1kSDrtZEg67WRIOu1kSDrtZEg67WRJpzrPb3ue3T5pddwlTivfsZkk47GZJOOxmSTjsZkk47GZJOOxmSTjsZkk47GZJOOxmSTjsZkk47GZJOOxmSTjsZkk47GZJOOxmSUxmyOZZwK3AdBpDNC+LiOslTQPuAubQGLb53Ij4YfdKtXQiSpufGX2ltL2fh6uuw2T27DuByyLiWODjwCWSjgWuAFZGxFxgZfG7mfWplmGPiM0R8UQx/SqwFpgJLAJWFN1WAGd3qUYzq8AefWeXNAc4AXgUmB4Rm4umLTQ+5ptZn5p02CUdDNwNfC4ito9ti4ig8X1+ouWWShqWNDwyOtJRsWbWvkmFXdJ+NIJ+e0R8vZi9VdKMon0GsG2iZSNiWUQMRcTQ4MBgFTWbWRtahl2SgOXA2oi4bkzTfcCSYnoJcG/15ZlZVSbzKOmTgQuApyWtLuZdCVwDfE3SRcAG4NyuVGh5SaXNb7Y4NWdv1zLsEfEI0Gyrn15tOWbWLb6CziwJh90sCYfdLAmH3SwJh90sCYfdLAkP2WxT1t1PbC1tP/9ED+k8lvfsZkk47GZJOOxmSTjsZkk47GZJOOxmSTjsZkn4PLv1L9+vXinv2c2S
cNjNknDYzZJw2M2ScNjNknDYzZJw2M2S8Hl2q82nPjGntP2Su8qfG297xnt2syQcdrMkHHazJBx2syQcdrMkHHazJBx2syRanmeXNAu4FZgOBLAsIq6XdDXwKWCk6HplRNzfrUJt79Pque7nP/Y3Paokh8lcVLMTuCwinpB0CPC4pIeKti9GxBe6V56ZVaVl2CNiM7C5mH5V0lpgZrcLM7Nq7dF3dklzgBOAR4tZl0p6StLNkg5vssxSScOShkdGRybqYmY9MOmwSzoYuBv4XERsB24AjgLm0djzXzvRchGxLCKGImJocGCw84rNrC2TCruk/WgE/faI+DpARGyNiDcjYhdwEzC/e2WaWadahl2SgOXA2oi4bsz8GWO6nQOsqb48M6vKZI7GnwxcADwtaXUx70pgsaR5NE7HrQcu7kJ9ZlaRyRyNfwSY6MZin1M3m0J8BZ1ZEg67WRIOu1kSDrtZEg67WRIOu1kSDrtZEg67WRIOu1kSDrtZEg67WRIOu1kSDrtZEg67WRKKiN6tTBoBNoyZNQCM9qyAPdOvtfVrXeDa2lVlbbMjYsLnv/U07O9YuTQcEUO1FVCiX2vr17rAtbWrV7X5Y7xZEg67WRJ1h31Zzesv06+19Wtd4Nra1ZPaav3Obma9U/ee3cx6xGE3S6KWsEtaIOm/JT0v6Yo6amhG0npJT0taLWm45lpulrRN0pox86ZJekjSc8XPCcfYq6m2qyVtKrbdakkLa6ptlqR/k/SspGckfbaYX+u2K6mrJ9ut59/ZJe0LfAf4eWAjsApYHBHP9rSQJiStB4YiovYLMCSdCuwAbo2I44t5fwm8FBHXFP9QHh4Rf9AntV0N7Kh7GO9itKIZY4cZB84GfoMat11JXefSg+1Wx559PvB8RKyLiDeAO4FFNdTR9yLiYeClcbMXASuK6RU0/mfpuSa19YWI2BwRTxTTrwK7hxmvdduV1NUTdYR9JvDimN830l/jvQfwoKTHJS2tu5gJTI+IzcX0FmB6ncVMoOUw3r00bpjxvtl27Qx/3ikfoHunUyLiROAs4JLi42pfisZ3sH46dzqpYbx7ZYJhxn+szm3X7vDnnaoj7JuAWWN+f18xry9ExKbi5zbgHvpvKOqtu0fQLX5uq7meH+unYbwnGmacPth2dQ5/XkfYVwFzJR0paX/gPOC+Gup4B0kHFQdOkHQQcAb9NxT1fcCSYnoJcG+NtbxNvwzj3WyYcWredrUPfx4RPX8BC2kckf8u8Ed11NCkrg8ATxavZ+quDbiDxse6H9E4tnER8F5gJfAc8K/AtD6q7SvA08BTNII1o6baTqHxEf0pYHXxWlj3tiupqyfbzZfLmiXhA3RmSTjsZkk47GZJOOxmSTjsZkk47GZJOOxmSfw/riy/z94biRwAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Render the sample at row 2 together with its label.\n",
    "mnist_dataset.plot_image(index=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Discriminator(nn.Module):\n",
    "    \"\"\"Discriminator network for a simple GAN.\n",
    "\n",
    "    A 784 -> 100 -> 1 fully connected network with a sigmoid after each\n",
    "    layer, so every output lies in (0, 1). The module owns its loss (MSE)\n",
    "    and optimiser (plain SGD, lr=0.01) and records the training loss so it\n",
    "    can be plotted with plot_progress().\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Linear(784, 100),\n",
    "            nn.Sigmoid(),\n",
    "            nn.Linear(100, 1),\n",
    "            nn.Sigmoid(),\n",
    "        )\n",
    "\n",
    "        self.loss_function = nn.MSELoss()\n",
    "\n",
    "        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)\n",
    "\n",
    "        # number of optimisation steps taken so far\n",
    "        self.counter = 0\n",
    "        # loss values sampled once every 10 optimisation steps\n",
    "        self.progress = []\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        \"\"\"Return the network's score(s) in (0, 1) for `inputs`.\"\"\"\n",
    "        return self.model(inputs)\n",
    "\n",
    "    def train(self, inputs=True, targets=None):\n",
    "        \"\"\"Run one SGD step on (inputs, targets).\n",
    "\n",
    "        This name shadows `nn.Module.train(mode=True)`, which PyTorch itself\n",
    "        calls (`eval()` is implemented as `train(False)`). A bare boolean with\n",
    "        no targets is therefore forwarded to `nn.Module.train`, so `D.eval()`\n",
    "        and `D.train()` keep switching the module mode instead of crashing.\n",
    "\n",
    "        Args:\n",
    "            inputs: input tensor for the network, or a bool mode flag.\n",
    "            targets: desired output tensor, same shape as the network output.\n",
    "\n",
    "        Returns:\n",
    "            self when used as a mode switch, otherwise None.\n",
    "        \"\"\"\n",
    "        if targets is None and isinstance(inputs, bool):\n",
    "            return super().train(inputs)\n",
    "\n",
    "        outputs = self.forward(inputs)\n",
    "        loss = self.loss_function(outputs, targets)\n",
    "        self.optimiser.zero_grad()\n",
    "        loss.backward()\n",
    "        self.optimiser.step()\n",
    "\n",
    "        self.counter += 1\n",
    "        # keep one loss sample per 10 steps to keep the progress list small\n",
    "        if self.counter % 10 == 0:\n",
    "            self.progress.append(loss.item())\n",
    "        if self.counter % 10000 == 0:\n",
    "            print(f'counter: {self.counter}')\n",
    "\n",
    "    def plot_progress(self):\n",
    "        \"\"\"Plot the recorded loss values (one point per 10 training steps).\"\"\"\n",
    "        df = pd.DataFrame(self.progress, columns=['loss'])\n",
    "        df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.8151, 0.4552, 0.3282, 0.1781])"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def generate_random(size):\n",
    "    \"\"\"Return a tensor of shape `size` filled with uniform samples from [0, 1).\"\"\"\n",
    "    uniform_noise = torch.rand(size)\n",
    "    return uniform_noise\n",
    "\n",
    "\n",
    "generate_random(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "counter: 10000\n",
      "counter: 20000\n",
      "counter: 30000\n",
      "counter: 40000\n",
      "counter: 50000\n",
      "counter: 60000\n",
      "counter: 70000\n",
      "counter: 80000\n",
      "counter: 90000\n",
      "counter: 100000\n",
      "counter: 110000\n",
      "counter: 120000\n",
      "Wall time: 1min 46s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Seed torch's RNG so weight init and the uniform-noise 'fake' images are reproducible.\n",
    "torch.manual_seed(42)\n",
    "D = Discriminator()\n",
    "\n",
    "# Target tensors never change, so build them once instead of twice per sample.\n",
    "real_target = torch.FloatTensor([1.0])\n",
    "fake_target = torch.FloatTensor([0.0])\n",
    "\n",
    "# Per image: one step toward 1.0 on real data, one step toward 0.0 on uniform noise.\n",
    "for _label, image_data_tensor, _one_hot in mnist_dataset:\n",
    "    D.train(image_data_tensor, real_target)\n",
    "    D.train(generate_random(784), fake_target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA6gAAAHSCAYAAADhZ+amAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAlrUlEQVR4nO3de5ClZ2Hn999zzunLdPfcJMGARjIStmxsQwBbgDeOhYxtYDdbYJe9ZVx7EV4wSWqxN5stp0xRsR3HLq+tVMhu1jFLebHR1tpSlpCNsuGy2FgFroBXIMRVxghZwAyjy4w09+nLOefJH30ktcbTo74c6Tw9+nyquua873nP6ae7n3p7vv2+5z2l1hoAAACYtM6kBwAAAACJQAUAAKARAhUAAIAmCFQAAACaIFABAABogkAFAACgCb1JD+B8V1xxRb3mmmsmPYyLOnPmTObn5yc9DBpkbnAx5gfrMTdYj7nBeswNLqb1+fGZz3zmaK31ORe6r7lAveaaa/LpT3960sO4qDvuuCM33njjpIdBg8wNLsb8YD3mBusxN1iPucHFtD4/SilfX+8+p/gCAADQBIEKAABAEwQqAAAATWjuNagAAADPJisrKzl06FAWFxfH8nx79+7NPffcM5bn2o7Z2dlcddVVmZqa2vBjBCoAAMAEHTp0KLt3784111yTUsq2n+/UqVPZvXv3GEa2dbXWHDt2LIcOHcq111674cc5xRcAAGCCFhcXc/nll48lTltRSsnll1++6aPCAhUAAGDCLqU4fcxWviaBCgAA8Cy3sLAw6SEkEagAAAA0QqACAADsMMv9Yc4s9bPcH471eWut+cVf/MW8+MUvzkte8pLcdtttSZIjR47khhtuyMte9rK8+MUvzic+8YkMBoO8+c1vfnzbd73rXdv+/K7iCwAA0IiTiyvpD+pFt1kZDPOt4+dSa1JKcuW+XZnqPnHs8fTZlax0lh9f7nVL9sxu7K1ePvCBD+Tuu+/O5z73uRw9ejSveMUrcsMNN+QP//AP87rXvS7vfOc7MxgMcvbs2dx99905fPhwvvjFLyZJjh8/vvkv+DyOoAIAAOwgK4Nhak3mZ7qpdXV5XP7sz/4sP/MzP5Nut5sDBw7k1a9+de6888684hWvyO///u/nV3/1V/OFL3whu3fvzgtf+MLcd999+fmf//l8+MMfzp49e7b9+QUqAABAI/bMTuWy+emLfjx392z2z01lqtvJ/rmpPHf37JPu3z/35OfY6NHTi7nhhhvy8Y9/PAcPHsyb3/zm3HLLLdm/f38+97nP5cYbb8y73/3uvPWtb9325xGoAAAAO8h0r5OD++dyYM9sDu6fy3RvfFn3Qz/0Q7ntttsyGAzy8MMP5+Mf/3he+cpX5utf/3oOHDiQn/u5n8tb3/rW3HXXXTl69GiGw2F+8id/Mr/+67+eu+66a9uf32tQAQAAdpjpXmesYfqYn/iJn8gnP/nJvPSlL00pJb/927+d5z3veXnf+96Xm2++OVNTU1lYWMgtt9ySw4cP52d/9mczHK6eYvybv/mb2/78AhUAAOBZ7vTp00mSUkpuvvnm3HzzzU+6/6abbspNN9301x43jqOmaznFFwAAgCYIVAAAAJogUAEAAGiCQAUAAJiwWuukhzB2W/maBCoAAMAEzc7O5tixY5dUpNZac+zYsczOzm7qca7iCwAAMEFXXXVVDh06lIcffngsz7e4uLjpMHw6zM7O5qqrrtrUYwQqAADABE1NTeXaa68d2/PdcccdefnLXz6253smOcUXAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACg
CQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmbChQSymvL6V8pZRybynlly5w/5tLKQ+XUu4efbx1zX03lVK+Ovq4aZyDBwAA4NLRe6oNSindJL+T5MeSHEpyZynl9lrrl8/b9LZa69vPe+xlSX4lyfVJapLPjB776FhGDwAAwCVjI0dQX5nk3lrrfbXW5SS3JnnjBp//dUk+Wmt9ZBSlH03y+q0NFQAAgEvZUx5BTXIwyTfXLB9K8qoLbPeTpZQbkvxlkn9Sa/3mOo89eP4DSylvS/K2JDlw4EDuuOOODQ1+Uk6fPt38GJkMc4OLMT9Yj7nBeswN1mNucDE7eX5sJFA34v9J8ke11qVSyn+V5H1JXrPRB9da35PkPUly/fXX1xtvvHFMw3p63HHHHWl9jEyGucHFmB+sx9xgPeYG6zE3uJidPD82corv4SRXr1m+arTucbXWY7XWpdHi7yX5/o0+FgAAAJKNBeqdSa4rpVxbSplO8qYkt6/doJTy/DWLb0hyz+j2R5K8tpSyv5SyP8lrR+sAAADgSZ7yFN9aa7+U8vashmU3yXtrrV8qpfxakk/XWm9P8gullDck6Sd5JMmbR499pJTyP2U1cpPk12qtjzwNXwcAAAA73IZeg1pr/WCSD5637pfX3H5Hknes
89j3JnnvNsYIAADAs8BGTvEFAACAp51ABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaIJABQAAoAkCFQAAgCYIVAAAAJogUAEAAGiCQAUAAKAJAhUAAIAmCNRNWu4Ps9ivWe4PJz0UAACAS4pA3YTl/jD3PngqD5wZ5vCjZ0UqAADAGAnUTVgZDLPUH6bXSepoGQAAgPEQqJsw1e2k00nO9ZMyWgYAAGA8FNYmTPc6uXLfruybKTm4fy7TPd8+AACAcVFYmzTV7WSmW8QpAADAmKmsTeoPapZcxRcAAGDsBOomLPeH+dbxczm+VF3FFwAAYMw2FKillNeXUr5SSrm3lPJLF7j/vyulfLmU8vlSyp+UUl6w5r5BKeXu0cft4xz8M21lMExNMtNzFV8AAIBx6z3VBqWUbpLfSfJjSQ4lubOUcnut9ctrNvtskutrrWdLKf9Nkt9O8tOj+87VWl823mFPxlS3k1KSRVfxBQAAGLuNFNYrk9xba72v1rqc5NYkb1y7Qa31T2utZ0eLn0py1XiH2YbHruK7dyau4gsAADBmGymsg0m+uWb50Gjdet6S5ENrlmdLKZ8upXyqlPLjmx9iW6Z7ncz0XMUXAABg3Eqt9eIblPJTSV5fa33raPnvJ3lVrfXtF9j27yV5e5JX11qXRusO1loPl1JemORjSX6k1vq18x73tiRvS5IDBw58/6233rr9r+xpcnal5vipM7nysoVJD4UGnT59OgsL5gYXZn6wHnOD9ZgbrMfc4GJanx8//MM//Jla6/UXuu8pX4Oa5HCSq9csXzVa9ySllB9N8s6sidMkqbUeHv17XynljiQvT/KkQK21vifJe5Lk+uuvrzfeeOMGhjUZJxdX8rE7PpGWx8jk3HHHHeYG6zI/WI+5wXrMDdZjbnAxO3l+bOQ81TuTXFdKubaUMp3kTUmedDXeUsrLk/yrJG+otT60Zv3+UsrM6PYVSX4wydqLK+04ZdIDAAAAuEQ95RHUWmu/lPL2JB9J0k3y3lrrl0opv5bk07XW25PcnGQhyb8rpSTJN2qtb0jy3Un+VSllmNUY/mfnXf0XAAAAkmzsFN/UWj+Y5IPnrfvlNbd/dJ3H/X9JXrKdAbamP6hZ7Ncs94culAQAADBGCmsTlvvDfOv4uRxfqjn86Nks94eTHhIAAMAlQ6BuwspgmOXBMKUky4NhVgYCFQAAYFwE6ibUmjxw/GweODXI4UfO5CneoQcAAIBNEKibsDIYZliT6W7JcLQMAADAeAjUTZrqdtPrlvQ63nAGAABgnATqJszP9HJg70xme8nz9sxlfmZDF0EGAABgAwTqJkz3OnnB
5fN5zq5Orrli3tvMAAAAjJHC2qSZXjczvZIpcQoAADBWKgsAAIAmCNRNWu4Ps9SvWe67gi8AAMA4CdRNWO4Pc+TEuRw7N8xXHzyZ04v9SQ8JAADgkiFQN2FlMMyZ5ZU8ujTM/UfP5GsPnXIkFQAAYEwE6ibUmnztoVM5fHKQrx87kwdPLubMkqOoAAAA4yBQN6GU5Pl757NnppP987NJyqSHBAAAcMkQqJsw1e3k8oWpLEyX7J2dysF9uzI/05v0sAAAAC4J6moTpnudXHvF7ly7t5vrr92fyxdmMu39UAEAAMZCXW3STK+TmV5Jr+v0XgAAgHESqJu03B/m6NlhvnHsbO4/etpVfAEAAMZEoG7SueV+Hl2qObO0kgdOuIovAADAuHgN6maVkuGw5tzKMFPyHgAAYGwk1iZNdTsZ1JqTZ5bTH9ZMdX0LAQAAxsER1E1aGQzTLSULc9PpdktWBl6DCgAAMA4O/21Bt1PSKclUx5V8AQAAxkWgbtL8TC+7ujVL/WH2zE5nfsZBaAAAgHEQqJu03B/mTD85t9LPI2eXvM0MAADAmAjUTTq73M+wJvt2zWQ4WgYAAGD7nJ+6SVOdTs6s1Hz96Onsn5vOVEfjAwAAjINA3aTpqU6et6vkmufO5zkLM5n2ZqgAAABjIVA3qdZksZacOLWSqZTU5056RAAAAJcGh/826bH3PZ2d6WZY4n1QAQAAxkSgbkGnUzLT63ofVAAAgDFyiu8mzc/0ctlMya6ZkgO7d3kfVAAAgDFRV5s03evkuXOdXH3ZfK7cO5fpnoPQAAAA46CutmCqWzI31ROnAAAAY6SwAAAAaIJA3YL+sObMcj/LfVfwBQAAGBeBuknL/WGOnas5emophx89K1IBAADGRKBu0spgmMGwJiVZ6g+8DyoAAMCYCNRNqjU5dm6Yv3rwdL716NnUOukRAQAAXBoE6iatDIapSWamehnU6ggqAADAmHgf1C0Y1mR2upNay6SHAgAAcMlwBHWT5md6WZhKllYGSamJU3wBAADGwhHULai15v5jZ0bfvJr//Nufm4VZ30oAAIDtcAR1k1Zfg1qytDLM4qDm0KOLOXFuedLDAgAA2PEc9tukqW4nU52aI6eXMjfdTa/jdagAAADjIFA3abrXyZW7e5ld2J1dU53snZ3O3LRvIwAAwHYpqy2Y7ZXsn5tOJ8nsdDdTXWdKAwAAbJdA3ZKS5+2bzcL0VKZ7nRRn+QIAAGybQ39b0Osk091uhjWZ7nYcQQUAABgDR1C3oNcpec7umaz0h3nO7tlM9wQqAADAdgnULegPa75x7ExWBjWLK4Ncd2CPSAUAANgmgboFp5eH+foDpzIz1c3DJ5fynN2zee6e2UkPCwAAYEcTqFswGCbL/ZpSBhl0BlkZDCc9JAAAgB1PoG5Bt1Nz/OxiVgbD7Ns1k6mO03sBAAC2S6BuSScvfO5Cdk31MtPtJd5mBgAAYNsc+tuC2V6yb9d0MkzmpjuZm9b5AAAA26WstmC218mLrt6fR84u59v2z2dh1rcRAABgu5TVFs3P9NIpRZwCAACMiVN8t+jsUj9HTy/l9GJ/0kMBAAC4JDj8twWL/WE+d+h4zi0NcuzUUl7+gsscSQUAANgmR1C3YLGf1Fqyb346gzrM2WVHUQEAALZLoG7BbC/plJrjZ5bTKcVVfAEAAMZAoG7BbK+T6w7szuW7Z/KdB/Y4vRcAAGAMBOoWLPaH+eqDp3L4kTP51H1H89CJxUkPCQAAYMdz6G8LFvtJXVzJ4eOLWVwepNPp5DUvOuBIKgAAwDY4groFvU7NQ6eW8sDJxQxKTacWF0oCAADYJof8tqDX6eQlV+7LueVBFgeD1AxcKAkAAGCbHEHdgl4nmdvVzXId5uzyMOeWh5MeEgAAwI4nULeg1ynplU4eOrGYUpOvPHg6R06cnfSwAAAAdjTnpW5Rp5PUYSfDOszycJiVgaOoAAAA2+EI6hbtnZ1OKcOcPLec1JK9s9OT
HhIAAMCO5gjqFk31Ovm+F+zPVLeTXVNTmeppfQAAgO1QVVs0N93L3FQ3S8uDzHTjKr4AAADbJFC34Vx/mHODmpVhnfRQAAAAdjyBukUnzi3nxJmVzHQ7OX52OSfOLU96SAAAADua81K3qibn+v2cWR5ktlcSB1EBAAC2RaBu0dxMLwf37cqj51ayd6bnIkkAAADbpKq2qianlwZ58MRiHji5mG89ei7Lfe+FCgAAsFWOoG7RynCYy+dnM6jJcFhzanE5K4Nhph1JBQAA2BKBukVz072UDPLVIyczLMl0t5MXH9yf+ZlJjwwAAGBncrhvixZme/n2A7tz+e6ZXHvFfDrd5OxKf9LDAgAA2LEcQd2Guelezq30c2Z5JWeXp1zJFwAAYBsE6jasDIY5eXaQlWE/iyuDnF12BBUAAGCrnOK7Df3BMIPBMMNhydmlQe4/dtqVfAEAALbIEdRtmOp28+jSUk6f62duupta40q+AAAAW6SktmFlOExGbzNzZnmQh04upnodKgAAwJYI1G3oD2qme93sm5/KdLeTmV4npUx6VAAAADuTQN2Gy+dncuXeXVlcGWaQmgdOLGV5xWtQAQAAtkKgbsNz9szmlddeloXZ6RxYmM0jZ5Zz9MzipIcFAACwIwnUbZjudXLV/rn0+/08fGoxXzxyMvccOeFKvgAAAFsgULepdJJHzizlyMnFPHj8bL585GSOn1me9LAAAAB2HIG6TeeWhlkeJIsryzm1uJyvHjnlNF8AAIAtEKjbNDfTzb75bhaXkrNLyZePnMgnv/ZwTi/2Jz00AACAHUWgbtPV++dzzf6F9LrJVC9Z7g/zlcOnc+Kc03wBAAA2Q6Bu0/6F6fzgd1yeqU6y2E+Onxvm7kOP5Mjxs5MeGgAAwI4iUMfg2w/szcHL5tNNUpM8cGIxH/nCYaf5AgAAbIJAHYMrFmZz2dxUBkmGSU6uJB/60uF84i+OTHpoAAAAO4ZAHYP52V5e86LnZff0E+sOnaj5H/795/O+T3w1hx85671RAQAAnoJAHYOpbiffdeW+vOjKPU9af3Qx+ZX/9y/z87femQ/ceb9QBQAAuIjepAdwKZjudfKi5+/Na7/3+fnS10/mZH3y/Xd943Tu+sY9uWbPV/Kq7ziQv/HCy/KcPXM5uH8uV+6fS5KsDIaZ6nYy3fM3AwAA4NlpQ4FaSnl9kn+epJvk92qt/+y8+2eS3JLk+5McS/LTtdb7R/e9I8lbkgyS/EKt9SNjG31DFmZ7+fGXf1seOLmUP/jE/Vm6wDb3nxzm/ruO5La7jqSTZH832bcvmZuezt65qVy5Z0++68o9ed7CTPbOzWT3rl727prOFbtnszDrbwkAAMCl7Smrp5TSTfI7SX4syaEkd5ZSbq+1fnnNZm9J8mit9TtKKW9K8ltJfrqU8j1J3pTke5NcmeSPSynfWWsdjPsLacH+hen816++Lntnu3n3R7+WkxfZdpjk2CA5dixJlkcfZ5LPPnFhpekke3rJ/K5kdrakdEpqTUqnpFOSOkxqakopKWuWkzy+LrUkpT7+74W2qcMkJel2SqY73fQ6yeodNb1uNzOdbnrdbnqls/qYJJ3SyWBYs1L7SU1mp3rZPTudmaleOp2afXMzue7AfL7t8oVMdTuZm+7l7Eo/J8+tZM/sVOZmelkZDDM31cvcTO+xT5dSVk+ZTp44qvxUt9c+brrXyenFfs4u9zM33XtS2C/3h45UAwBAwzZyWO6VSe6ttd6XJKWUW5O8McnaQH1jkl8d3X5/kn9ZSimj9bfWWpeS/FUp5d7R831yPMNvz/6F6bz11d+Z775yT/63j/1F7vrmuS0/13KSo/3k6Kkkp2qS+hSPGIfxvjXOVJK5kpQkg5KkJsOazHSTXi+Zqkl3Jul2k34/mZ5K8lhTZ3XbTmc1QIfD1cd3uiW11sefq1+T6e7qNiXJueWkN/pc07OrQV1rzXC4ukEd1nR7q+tLeSz86yh0S8oo4evo2/14
yNeaOnxy3A+zOsjS6aSU5PSDg8zd9aEntlnzXH/teUp58rp1Pt+Txpg124y+rsfXZc3jHlvurH69T3pceWJdp5Tk/Oc5b5t1n/v8cXfWPO5C6y723J0nb7P2DzAb+nzrPfcG1o1rm/PXPbacNctnHhpm/u4Pr/7xaJ0/Fj3VH53O32Y4fGK/sJnHXWyb88c+7u9TZxtj2ur36el47nGO6ezDwyx87sPPmq930o/bSWM689Ag85/9cFNjavH79Gwc0+kHB5n77IeaGlOL36dn45ieszCfq8uZvPT65exfWHMV1x1iI4F6MMk31ywfSvKq9baptfZLKSeSXD5a/6nzHntwy6PdIaZ7nfzwd1+Zl119Rf7sL7+VP/7iw/nC4Ydy34lJj+yZt5LkxGP/K13T12cGWT3pO1kt8U3ZRKifWrvthW5vNfrPf9yakwIe2uiFsDb6uZ+JP0zwjHngkjyBhHE4Ym6wjgfNDdbxoItv8tfd8+DqeZzlI1/KP33d9+64SG3ihY2llLcledto8XQp5SuTHM8GXJHk6MY3LyWldMrUzGxnZn5v6U3PppRuKZ1OUjrplG5Kt1tKebrOOy3ZWOVcaLsy+reuuc06hsNhOh2nD3Nh5gfrMTdYj7nBeswNLmY4HNbfes/Sid84+fChDFY2fTjoGfCC9e7YSKAeTnL1muWrRusutM2hUkovyd6sXixpI49NrfU9Sd6zgbE0oZTy6Vrr9ZMeB+0xN7gY84P1mBusx9xgPeYGF7OT58dG/uxyZ5LrSinXllKms3rRo9vP2+b2JDeNbv9Uko/VWuto/ZtKKTOllGuTXJfkP41n6AAAAFxKnvII6ug1pW9P8pGsvs3Me2utXyql/FqST9dab0/yr5P8m7J6EaRHshqxGW33f2T1gkr9JP/oUr2CLwAAANuzodeg1lo/mOSD56375TW3F5P8nXUe+xtJfmMbY2zRjjkdmWecucHFmB+sx9xgPeYG6zE3uJgdOz9Kra4QCgAAwOS59BcAAABNEKibVEp5fSnlK6WUe0spvzTp8fD0K6VcXUr501LKl0spXyql/OPR+stKKR8tpXx19O/+0fpSSvkXozny+VLK9615rptG23+1lHLTep+TnaWU0i2lfLaU8h9Gy9eWUv58NAduG11gLqMLxt02Wv/npZRr1jzHO0brv1JKed2EvhTGqJSyr5Ty/lLKX5RS7iml/A37DZKklPJPRr9PvlhK+aNSyqz9xrNXKeW9pZSHSilfXLNubPuKUsr3l1K+MHrMvyileNvAHWKduXHz6PfK50sp/1cpZd+a+y64T1ivX9bb70yaQN2EUko3ye8k+ZtJvifJz5RSvmeyo+IZ0E/yT2ut35PkB5L8o9HP/ZeS/Emt9bokfzJaTlbnx3Wjj7cl+d1k9ZdNkl9J8qokr0zyK4/9wmHH+8dJ7lmz/FtJ3lVr/Y4kjyZ5y2j9W5I8Olr/rtF2Gc2nNyX53iSvT/K/j/Y37Gz/PMmHa60vSvLSrM4R+41nuVLKwSS/kOT6WuuLs3oByjfFfuPZ7A+y+jNca5z7it9N8nNrHnf+56Jdf5C//vP6aJIX11r/syR/meQdyfr7hKfol/X2OxMlUDfnlUnurbXeV2tdTnJrkjdOeEw8zWqtR2qtd41un8rqfzIPZvVn/77RZu9L8uOj229Mcktd9akk+0opz0/yuiQfrbU+Umt9NKs7GL8kdrhSylVJ/sskvzdaLklek+T9o03OnxuPzZn3J/mR0fZvTHJrrXWp1vpXSe7N6v6GHaqUsjfJDVm9yn1qrcu11uOx32BVL8musvre8XNJjsR+41mr1vrxrL4Lxlpj2VeM7ttTa/3U6C0gb1nzXDTuQnOj1vofa6390eKnklw1ur3ePuGC/fIU/1+ZKIG6OQeTfHPN8qHROp4lRqdWvTzJnyc5UGs9MrrrgSQHRrfXmyfmz6Xpf03y3ycZjpYvT3J8zS+PtT/nx+fA6P4To+3NjUvPtUke
TvL7ZfX0798rpczHfuNZr9Z6OMn/nOQbWQ3TE0k+E/sNnmxc+4qDo9vnr+fS8A+TfGh0e7Nz42L/X5kogQobVEpZSPJ/Jvlva60n1943+qukS2I/y5RS/naSh2qtn5n0WGhOL8n3JfndWuvLk5zJE6foJbHfeLYanXb5xqz+EePKJPNxVJyLsK/gQkop78zqy9D+7aTHMm4CdXMOJ7l6zfJVo3Vc4kopU1mN039ba/3AaPWDo1NnMvr3odH69eaJ+XPp+cEkbyil3J/VU2Zek9XXHe4bnbqXPPnn/PgcGN2/N8mxmBuXokNJDtVa/3y0/P6sBqv9Bj+a5K9qrQ/XWleSfCCr+xL7DdYa177icJ44BXTtenawUsqbk/ztJH+3PvGeoZudG8ey/n5nogTq5tyZ5LrRFa+ms/pC5NsnPCaeZqNz9P91kntqrf/LmrtuT/LYVfJuSvJ/r1n/D0ZX2vuBJCdGp+l8JMlrSyn7R39Bf+1oHTtUrfUdtdaraq3XZHV/8LFa699N8qdJfmq02flz47E581Oj7eto/ZtGV+u8NqsXsfhPz9CXwdOg1vpAkm+WUr5rtOpHknw59husntr7A6WUudHvl8fmhv0Ga41lXzG672Qp5QdG8+0frHkudqBSyuuz+tKiN9Raz665a719wgX7ZbQfWW+/M1m1Vh+b+Ejyt7J6xayvJXnnpMfj4xn5mf8XWT215vNJ7h59/K2snrv/J0m+muSPk1w22r5k9WppX0vyhaxeqfGx5/qHWX3R+r1JfnbSX5uPsc6TG5P8h9HtF2b1l8K9Sf5dkpnR+tnR8r2j+1+45vHvHM2ZryT5m5P+enyMZU68LMmnR/uOf59kv/2Gj9HP9H9M8hdJvpjk3ySZsd949n4k+aOsvh55JatnX7xlnPuKJNeP5trXkvzLJGXSX7OPbc2Ne7P6mtLH/k/67jXbX3CfkHX6Zb39zqQ/ymhwAAAAMFFO8QUAAKAJAhUAAIAmCFQAAACaIFABAABogkAFAACgCQIVAACAJghUAAAAmiBQAQAAaML/D5sqLTq096bAAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 1152x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Loss curve recorded during training (one point per 10 optimisation steps).\n",
    "D.plot_progress();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9929034113883972\n",
      "0.996117115020752\n",
      "0.9969319105148315\n",
      "0.9939669370651245\n"
     ]
    }
   ],
   "source": [
    "# random.randint includes both endpoints, so randint(0, 60000) could pick row 60000,\n",
    "# one past the last row; randrange over len() always stays in bounds.\n",
    "for i in range(4):\n",
    "    image_data_tensor = mnist_dataset[random.randrange(len(mnist_dataset))][1]\n",
    "    print(D.forward(image_data_tensor).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0073259370401501656\n",
      "0.0068556466139853\n",
      "0.0058015501126646996\n",
      "0.004600544925779104\n"
     ]
    }
   ],
   "source": [
    "# Scores for uniform-noise inputs, which D was trained to push toward 0.0.\n",
    "for sample_idx in range(4):\n",
    "    noise = generate_random(784)\n",
    "    score = D.forward(noise)\n",
    "    print(score.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('ai')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "0cb0046d74bfb5a9fee6ebb85b9850d762e40d6b321d5d9a0aa0313eefa3c5b4"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
